use std::collections::HashMap;
use std::fs::create_dir_all;
use std::ops::Not;
use std::path::Path;
use std::sync::Arc;

use anyhow::ensure;
use containerd_shim::api::{
    ConnectRequest, ConnectResponse, CreateTaskRequest, CreateTaskResponse, DeleteRequest, Empty,
    KillRequest, ShutdownRequest, StartRequest, StartResponse, StateRequest, StateResponse,
    StatsRequest, StatsResponse, WaitRequest, WaitResponse,
};
use containerd_shim::error::Error as ShimError;
use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, TaskIO, TaskStart};
use containerd_shim::protos::shim::shim_ttrpc::Task;
use containerd_shim::protos::types::task::Status;
use containerd_shim::util::IntoOption;
use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult};
use futures::FutureExt as _;
use log::debug;
use oci_spec::runtime::Spec;
use prost::Message;
use protobuf::well_known_types::any::Any;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
#[cfg(feature = "opentelemetry")]
use tracing_opentelemetry::OpenTelemetrySpanExt as _;

#[cfg(feature = "opentelemetry")]
use super::otel::extract_context;
use crate::sandbox::async_utils::AmbientRuntime as _;
use crate::sandbox::instance::{Instance, InstanceConfig};
use crate::sandbox::shim::events::{EventSender, RemoteEventSender, ToTimestamp};
use crate::sandbox::shim::instance_data::InstanceData;
use crate::sandbox::sync::WaitableCell;
use crate::sandbox::{Error, Result, oci};
use crate::sys::metrics::get_metrics;

#[cfg(test)]
mod tests;

/// containerd runtime options.
///
/// Hand-written prost message used to decode the `value` payload of the
/// `options` field of a `CreateTaskRequest` (see `Config::get_from_options`).
// NOTE(review): no explicit `tag` attributes are given, so the protobuf field
// numbers presumably follow declaration order — confirm against the prost
// derive documentation before reordering or inserting fields.
#[derive(Message, Clone, PartialEq)]
struct Options {
    /// Type URL carried inside the decoded message; not read in this file.
    #[prost(string)]
    type_url: String,
    /// Not read in this file; presumably the path of a configuration file — TODO confirm.
    #[prost(string)]
    config_path: String,
    /// Configuration text; `Config::get_from_options` parses it as TOML.
    #[prost(string)]
    config_body: String,
}

/// Shim configuration supplied by containerd at task creation.
///
/// This is generated by decoding the `options` field of a `CreateTaskRequest` to get an `Options` struct,
/// interpreting the `config_body` field as TOML,
/// and deserializing it.
///
/// When the request carries no options at all, `Config::default()` is used instead.
// NOTE(review): there is no `#[serde(default)]` here, so a TOML body that omits
// `systemd_cgroup` presumably fails to deserialize — confirm this is intended.
#[derive(Serialize, Deserialize, Default, Clone, PartialEq, Debug)]
pub struct Config {
    /// Enables systemd cgroup.
    ///
    /// Accepted under the key `systemd_cgroup` or its alias `SystemdCgroup`.
    #[serde(alias = "SystemdCgroup")]
    pub systemd_cgroup: bool,
}

impl Config {
    /// Extracts the shim [`Config`] from the `options` field of a `CreateTaskRequest`.
    ///
    /// * `None` yields `Config::default()`.
    /// * Otherwise the `Any` payload must have the type URL
    ///   `runtimeoptions.v1.Options`; its value is decoded as an [`Options`]
    ///   message and the `config_body` field is parsed as TOML.
    ///
    /// # Errors
    ///
    /// Returns an error when the type URL does not match, when the payload is
    /// not a valid `Options` message, or when `config_body` is not valid TOML
    /// for [`Config`]. Errors are returned without any extra prefix; the caller
    /// is expected to add its own context.
    fn get_from_options(options: Option<&Any>) -> anyhow::Result<Self> {
        let Some(opts) = options else {
            return Ok(Self::default());
        };

        ensure!(
            opts.type_url == "runtimeoptions.v1.Options",
            "Invalid options type {}",
            opts.type_url
        );

        let opts = Options::decode(opts.value.as_slice())?;

        // Propagate the TOML error untouched, like the two failure paths above.
        // Wrapping it in `Error::InvalidArgument("invalid shim options: ...")`
        // here made the prefix appear twice once the caller added its own.
        let config: Self = toml::from_str(opts.config_body.as_str())?;

        Ok(config)
    }
}

/// Async-locked registry of the tasks known to a `Local` service.
/// Keys are the task ids taken from `CreateTaskRequest::id`; values are shared
/// handles so an instance can keep being used after the lock guard is released.
type LocalInstances<T> = RwLock<HashMap<String, Arc<InstanceData<T>>>>;

/// Local implements the Task service for a containerd shim.
/// It defers all task operations to the `Instance` implementation.
///
/// `T` is the `Instance` implementation backing every task; `E` is the event
/// sink, which defaults to `RemoteEventSender`.
pub struct Local<T: Instance + Send + Sync, E: EventSender = RemoteEventSender> {
    /// Registered tasks, keyed by task id.
    pub(super) instances: LocalInstances<T>,
    /// Sink for the task lifecycle events (`TaskCreate`, `TaskStart`, `TaskExit`, `TaskDelete`).
    events: E,
    /// Set by `shutdown` once no instances remain registered.
    exit: WaitableCell<()>,
    /// Copied into the `InstanceConfig` of every task created by this service.
    namespace: String,
    /// Copied into the `InstanceConfig` of every task created by this service.
    containerd_address: String,
}

impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
    /// Creates a new local task service.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(events, exit), level = "Debug")
    )]
    pub fn new(
        events: E,
        exit: WaitableCell<()>,
        namespace: impl AsRef<str> + std::fmt::Debug,
        containerd_address: impl AsRef<str> + std::fmt::Debug,
    ) -> Self {
        let instances = RwLock::default();
        let namespace = namespace.as_ref().to_string();
        let containerd_address = containerd_address.as_ref().to_string();
        Self {
            instances,
            events,
            exit,
            namespace,
            containerd_address,
        }
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    pub(super) async fn get_instance(&self, id: &str) -> Result<Arc<InstanceData<T>>> {
        let instance = self.instances.read().await.get(id).cloned();
        instance.ok_or_else(|| Error::NotFound(id.to_string()))
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn has_instance(&self, id: &str) -> bool {
        self.instances.read().await.contains_key(id)
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn is_empty(&self) -> bool {
        self.instances.read().await.is_empty()
    }
}

// These are the same functions as in Task, but without the TtrpcContext, which is useful for testing
impl<T: Instance + Send + Sync, E: EventSender> Local<T, E> {
    /// Creates a task from `req` and registers it under `req.id`.
    ///
    /// Steps, in order:
    /// 1. decode the shim options into a [`Config`];
    /// 2. reject unsupported requests (checkpoints, terminal) and duplicate ids;
    /// 3. load `config.json` from the bundle and canonicalize its rootfs;
    /// 4. create the rootfs directory (best effort) and, on unix, mount every
    ///    rootfs mount listed in the request onto it;
    /// 5. build the `InstanceData`, store it, and send a `TaskCreate` event;
    /// 6. run the prestart hooks declared in the runtime spec.
    ///
    /// The `pid` of the returned response is the shim's own process id.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidArgument` for invalid shim options, a terminal request,
    ///   an unreadable runtime spec, or a spec without a root.
    /// * `ShimError::Unimplemented` when a checkpoint is requested.
    /// * `Error::AlreadyExists` when `req.id` is already registered.
    /// * Any error from canonicalizing or mounting the rootfs, from
    ///   `InstanceData::new`, or from the prestart hooks.
    // NOTE(review): the prestart hooks run after the instance has been stored
    // and the `TaskCreate` event has been sent, so a hook failure returns an
    // error while leaving the instance registered — confirm this is intended.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_create(&self, req: CreateTaskRequest) -> Result<CreateTaskResponse> {
        let config = Config::get_from_options(req.options.as_ref())
            .map_err(|err| Error::InvalidArgument(format!("invalid shim options: {err}")))?;

        if !req.checkpoint().is_empty() || !req.parent_checkpoint().is_empty() {
            return Err(ShimError::Unimplemented("checkpoint is not supported".to_string()).into());
        }

        if req.terminal {
            return Err(Error::InvalidArgument(
                "terminal is not supported".to_string(),
            ));
        }

        if self.has_instance(&req.id).await {
            return Err(Error::AlreadyExists(req.id));
        }

        let mut spec = Spec::load(Path::new(&req.bundle).join("config.json"))
            .map_err(|err| Error::InvalidArgument(format!("could not load runtime spec: {err}")))?;

        spec.canonicalize_rootfs(req.bundle()).map_err(|err| {
            ShimError::InvalidArgument(format!("could not canonicalize rootfs: {}", err))
        })?;

        let rootfs = spec
            .root()
            .as_ref()
            .ok_or_else(|| Error::InvalidArgument("rootfs is not set in runtime spec".to_string()))?
            .path();

        // Best effort: a failure here is not fatal on its own, but it is logged
        // instead of being dropped silently so later failures can be traced back.
        if let Err(err) = create_dir_all(rootfs) {
            debug!(
                "could not create rootfs directory {}: {err}",
                rootfs.display()
            );
        }

        // Iterate the request's mounts by reference; nothing here needs an
        // owned copy of the mount list or of each mount's option list.
        for m in req.rootfs() {
            // Empty strings mean "not set" for the mount call. The underscore
            // prefixes keep the bindings warning-free on non-unix targets,
            // where the mount call below is compiled out.
            let _mount_type = m.type_().none_if(|&x| x.is_empty());
            let _source = m.source.as_str().none_if(|&x| x.is_empty());

            #[cfg(unix)]
            containerd_shim::mount::mount_rootfs(_mount_type, _source, &m.options, rootfs)?;
        }

        let cfg = InstanceConfig {
            namespace: self.namespace.clone(),
            containerd_address: self.containerd_address.clone(),
            bundle: req.bundle.as_str().into(),
            stdout: req.stdout.as_str().into(),
            stderr: req.stderr.as_str().into(),
            stdin: req.stdin.as_str().into(),
            config,
        };

        // Build the instance state for this task and make it visible to the
        // other task operations.
        let instance = InstanceData::new(req.id(), cfg).await?;

        self.instances
            .write()
            .await
            .insert(req.id().to_string(), Arc::new(instance));

        self.events.send(TaskCreate {
            container_id: req.id,
            bundle: req.bundle,
            rootfs: req.rootfs,
            io: Some(TaskIO {
                stdin: req.stdin,
                stdout: req.stdout,
                stderr: req.stderr,
                ..Default::default()
            })
            .into(),
            ..Default::default()
        });

        debug!("create done");

        // Per the spec, the prestart hook must be called as part of the create operation
        debug!("call prehook before the start");
        oci::setup_prestart_hooks(spec.hooks())?;

        Ok(CreateTaskResponse {
            pid: std::process::id(),
            ..Default::default()
        })
    }

    /// Starts the task named by `req` and reports its pid.
    ///
    /// Sends a `TaskStart` event right away, then spawns a background future
    /// that waits for the task to exit and sends the matching `TaskExit` event.
    ///
    /// # Errors
    ///
    /// `ShimError::Unimplemented` when an exec id is given, `Error::NotFound`
    /// when the task is unknown, or whatever error starting the instance returns.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_start(&self, req: StartRequest) -> Result<StartResponse> {
        let exec_requested = req.exec_id().is_empty().not();
        if exec_requested {
            return Err(ShimError::Unimplemented("exec is not supported".to_string()).into());
        }

        let instance = self.get_instance(req.id()).await?;
        let pid = instance.start().await?;

        let started = TaskStart {
            container_id: req.id().into(),
            pid,
            ..Default::default()
        };
        self.events.send(started);

        // The watcher owns its own event sender, task id and instance handle,
        // so it can outlive this call.
        let exit_events = self.events.clone();
        let task_id = req.id().to_string();
        let exit_watcher = async move {
            let (exit_status, exited_at) = instance.wait().await;
            let exited = TaskExit {
                container_id: task_id.clone(),
                exit_status,
                exited_at: Some(exited_at.to_timestamp()).into(),
                pid,
                id: task_id,
                ..Default::default()
            };
            exit_events.send(exited);
        };
        exit_watcher.spawn();

        debug!("started: {:?}", req);

        Ok(StartResponse {
            pid,
            ..Default::default()
        })
    }

    /// Delivers the signal from `req` to the task it names.
    ///
    /// # Errors
    ///
    /// `Error::InvalidArgument` when an exec id is given, `Error::NotFound`
    /// when the task is unknown, or whatever error the instance's `kill` returns.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_kill(&self, req: KillRequest) -> Result<Empty> {
        let exec_requested = !req.exec_id().is_empty();
        if exec_requested {
            return Err(Error::InvalidArgument("exec is not supported".to_string()));
        }

        let instance = self.get_instance(req.id()).await?;
        instance.kill(req.signal()).await?;

        Ok(Empty::new())
    }

    /// Deletes the task named by `req` and removes it from the registry.
    ///
    /// The response and the `TaskDelete` event carry the task pid (0 when it
    /// has none) and, if the exit status is already available without waiting,
    /// the exit status and exit timestamp; otherwise those fields keep their
    /// default values.
    ///
    /// # Errors
    ///
    /// `Error::InvalidArgument` when an exec id is given, `Error::NotFound`
    /// when the task is unknown, or whatever error the instance's `delete`
    /// returns (in which case the task stays registered).
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_delete(&self, req: DeleteRequest) -> Result<DeleteResponse> {
        let exec_requested = !req.exec_id().is_empty();
        if exec_requested {
            return Err(Error::InvalidArgument("exec is not supported".to_string()));
        }

        let instance = self.get_instance(req.id()).await?;
        instance.delete().await?;

        let pid = instance.pid().unwrap_or_default();

        // Poll the exit future exactly once: never block a delete on a task
        // that has not reported an exit yet.
        let (exit_status, exited_at) = match instance.wait().now_or_never() {
            Some((code, when)) => (code, Some(when.to_timestamp())),
            None => (Default::default(), None),
        };

        self.instances.write().await.remove(req.id());

        let deleted = TaskDelete {
            container_id: req.id().into(),
            pid,
            exit_status,
            exited_at: exited_at.clone().into(),
            ..Default::default()
        };
        self.events.send(deleted);

        Ok(DeleteResponse {
            pid,
            exit_status,
            exited_at: exited_at.into(),
            ..Default::default()
        })
    }

    /// Waits until the task named by `req` has exited and reports its exit
    /// status and exit timestamp.
    ///
    /// # Errors
    ///
    /// `Error::InvalidArgument` when an exec id is given, or `Error::NotFound`
    /// when the task is unknown.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_wait(&self, req: WaitRequest) -> Result<WaitResponse> {
        let exec_requested = !req.exec_id().is_empty();
        if exec_requested {
            return Err(Error::InvalidArgument("exec is not supported".to_string()));
        }

        let instance = self.get_instance(req.id()).await?;
        let (exit_status, exited_at) = instance.wait().await;

        debug!("wait finishes");

        Ok(WaitResponse {
            exit_status,
            exited_at: Some(exited_at.to_timestamp()).into(),
            ..Default::default()
        })
    }

    /// Reports the current state of the task named by `req`.
    ///
    /// The status is derived from what the instance exposes right now:
    /// * no pid: `CREATED`;
    /// * a pid but no exit result yet: `RUNNING`;
    /// * a pid and an exit result: `STOPPED`.
    ///
    /// The pid, exit status and exit timestamp fall back to their default
    /// values when not available.
    ///
    /// # Errors
    ///
    /// `Error::InvalidArgument` when an exec id is given, or `Error::NotFound`
    /// when the task is unknown.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_state(&self, req: StateRequest) -> Result<StateResponse> {
        let exec_requested = !req.exec_id().is_empty();
        if exec_requested {
            return Err(Error::InvalidArgument("exec is not supported".to_string()));
        }

        let instance = self.get_instance(req.id()).await?;
        let pid = instance.pid();
        // Poll the exit future once, without waiting for the task to finish.
        let exited = instance.wait().now_or_never();

        let status = match (pid, &exited) {
            (None, _) => Status::CREATED,
            (Some(_), None) => Status::RUNNING,
            (Some(_), Some(_)) => Status::STOPPED,
        };

        let (exit_status, exited_at) = match exited {
            Some((code, when)) => (code, Some(when.to_timestamp())),
            None => (Default::default(), None),
        };

        let cfg = &instance.config;
        Ok(StateResponse {
            bundle: cfg.bundle.to_string_lossy().into_owned(),
            stdin: cfg.stdin.to_string_lossy().into_owned(),
            stdout: cfg.stdout.to_string_lossy().into_owned(),
            stderr: cfg.stderr.to_string_lossy().into_owned(),
            pid: pid.unwrap_or_default(),
            exit_status,
            exited_at: exited_at.into(),
            status: status.into(),
            ..Default::default()
        })
    }

    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Debug"))]
    async fn task_stats(&self, req: StatsRequest) -> Result<StatsResponse> {
        let i = self.get_instance(req.id()).await?;
        let pid = i
            .pid()
            .ok_or_else(|| Error::InvalidArgument("task is not running".to_string()))?;

        let metrics = get_metrics(pid)?;

        Ok(StatsResponse {
            stats: Some(metrics).into(),
            ..Default::default()
        })
    }
}

impl<T: Instance + Sync + Send, E: EventSender> Task for Local<T, E> {
    /// ttrpc entry point for task creation.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_create` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn create(
        &self,
        _ctx: &TtrpcContext,
        req: CreateTaskRequest,
    ) -> TtrpcResult<CreateTaskResponse> {
        debug!("create: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_create(req).block_on()?;
        Ok(response)
    }

    /// ttrpc entry point for starting a task.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_start` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult<StartResponse> {
        debug!("start: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_start(req).block_on()?;
        Ok(response)
    }

    /// ttrpc entry point for signalling a task.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_kill` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult<Empty> {
        debug!("kill: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_kill(req).block_on()?;
        Ok(response)
    }

    /// ttrpc entry point for deleting a task.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_delete` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult<DeleteResponse> {
        debug!("delete: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_delete(req).block_on()?;
        Ok(response)
    }

    /// ttrpc entry point for waiting on a task's exit.
    ///
    /// Logs the request and blocks on `task_wait`. While waiting, the call is
    /// raced (via `tokio::select!`) against a `span_exporter` future:
    /// * with the `opentelemetry` feature, an endless loop that opens a
    ///   "task wait 60s interval" span under the request's span and sleeps for
    ///   60 seconds per iteration;
    /// * without it, a future that never resolves.
    ///
    /// Either way `span_exporter` cannot complete, so the `select!` always
    /// resolves with the result of `task_wait`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult<WaitResponse> {
        debug!("wait: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        let span_exporter = {
            use tracing::{Level, Span, span};
            let parent_span = Span::current();
            parent_span.set_parent(extract_context(&_ctx.metadata));

            // This future never completes as it runs an infinite loop.
            // It will stop executing when dropped.
            // We need to keep this future's lifetime tied to this
            // method's lifetime.
            // This means we shouldn't tokio::spawn it, but rather
            // tokio::select! it inside of this async method.
            async move {
                loop {
                    let current_span =
                        span!(parent: &parent_span, Level::INFO, "task wait 60s interval");
                    // NOTE(review): this guard stays alive across the `.await`
                    // below. The tracing documentation advises against holding
                    // an `enter()` guard over an await point — confirm this is
                    // deliberate here (the span presumably needs to cover the
                    // whole interval) before changing it.
                    let _enter = current_span.enter();
                    tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                }
            }
        };

        // Without opentelemetry there is nothing to export: use a future that
        // never resolves so the `select!` below has the same shape.
        #[cfg(not(feature = "opentelemetry"))]
        let span_exporter = std::future::pending::<()>();

        let res = async {
            tokio::select! {
                _ = span_exporter => unreachable!(),
                res = self.task_wait(req) => res,
            }
        }
        .block_on()?;

        Ok(res)
    }

    /// ttrpc entry point returning the shim pid and the pid of the task named
    /// by `req` (0 when the task has no pid).
    ///
    /// Fails with the error from `get_instance` when the task is unknown.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn connect(&self, _ctx: &TtrpcContext, req: ConnectRequest) -> TtrpcResult<ConnectResponse> {
        debug!("connect: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let instance = self.get_instance(req.id()).block_on()?;

        Ok(ConnectResponse {
            shim_pid: std::process::id(),
            task_pid: instance.pid().unwrap_or_default(),
            ..Default::default()
        })
    }

    /// ttrpc entry point for querying a task's state.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_state` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult<StateResponse> {
        debug!("state: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_state(req).block_on()?;
        Ok(response)
    }

    /// ttrpc entry point for shutting the shim down.
    ///
    /// Sets the `exit` cell only when no instances remain registered; the
    /// request always succeeds, whether or not the cell was set.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult<Empty> {
        debug!("shutdown");

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let no_tasks_left = self.is_empty().block_on();
        if no_tasks_left {
            // The outcome of `set` is deliberately ignored.
            let _ = self.exit.set(());
        }

        Ok(Empty::new())
    }

    /// ttrpc entry point for collecting a task's metrics.
    ///
    /// Logs the request, links the current span to the trace context found in
    /// the request metadata when the `opentelemetry` feature is enabled, and
    /// drives `task_stats` to completion with `block_on`.
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))]
    fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult<StatsResponse> {
        debug!("stats: {:?}", req);

        #[cfg(feature = "opentelemetry")]
        tracing::Span::current().set_parent(extract_context(&_ctx.metadata));

        let response = self.task_stats(req).block_on()?;
        Ok(response)
    }
}
